import numpy as np
import torch
from pathlib import Path

import models
from utils import *
import time


# Default hyper-parameters shared by every experiment in this file.
ARGS = {
    'device_bsize': 1000,  # batch size of every data loader built in train()
    'batch_size': 2000,    # min. number of samples accumulated per gradient
    'T': 8000,             # number of training iterations
    'lr': 0.001,           # initial learning rate (decayed inside train())
    'eps': 0.004,          # grid width for rounding the gradient difference
    'mmt': 0.99,           # momentum coefficient
    'n_m': 5,              # ratio n / m handed to get_datasets()
}
# Values of eps / mmt that are not above this threshold count as "disabled".
ZERO = 1e-6

def get_datasets(n_m=5, p=0, data_path='dataset/cifar10'):
  """Load CIFAR-10 and split the training set into two disjoint parts I and J.

  Args:
    n_m: ratio of n to m; with n training samples, the subset J receives
      m = int(n / n_m) indices drawn uniformly without replacement and I
      receives the remaining n - m indices.
    p: portion of random labels, forwarded to `random_label` for every
      dataset created here (the test set included).
    data_path: directory the CIFAR-10 files are read from / downloaded to.

  Returns:
    (SI, SI_aug, SJ, SJ_aug, D): the I and J subsets, each once with the plain
    transform and once with the augmenting transform, followed by the test
    set D.
  """
  channel_mean = [v / 255 for v in [125.3, 123.0, 113.9]]
  channel_std = [v / 255 for v in [63.0, 62.1, 66.7]]

  # Training-time augmentation: random crop + horizontal flip.
  augmenting_tf = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std)])
  # Evaluation transform: normalisation only.
  plain_tf = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(channel_mean, channel_std)])

  def _load(is_train, tf):
    # Build one CIFAR-10 view and apply the label randomisation to it.
    ds = dset.CIFAR10(data_path, train=is_train, transform=tf, download=True)
    random_label(ds, p)
    return ds

  # NOTE(review): S_aug and S are relabelled by two separate random_label
  # calls; whether both end up with the same labels depends on how
  # random_label draws them -- TODO confirm.
  S_aug = _load(True, augmenting_tf)
  S = _load(True, plain_tf)
  D = _load(False, plain_tf)

  n = len(S)
  m = int(n / n_m)
  # J: m distinct indices; I: everything else, in increasing order.
  J = np.random.choice(n, m, replace=False)
  keep = np.ones(n, dtype=np.bool_)
  keep[J] = False
  I = np.arange(n)[keep]

  make_subset = torch.utils.data.Subset
  SI, SI_aug = make_subset(S, I), make_subset(S_aug, I)
  SJ, SJ_aug = make_subset(S, J), make_subset(S_aug, J)
  return SI, SI_aug, SJ, SJ_aug, D


def compute_sum_grad(model, loss_fn, lr, data_ld, g_out, bsize):
  """Accumulate gradients over at least `bsize` samples; store lr * grad in g_out.

  Batches are pulled from `data_ld` until the running sample count reaches
  `bsize` (or the loader is exhausted). Gradients of `loss_fn` are accumulated
  across those batches by repeated `backward()` calls, so the result is a
  gradient *sum* over batches, not a mean.

  Args:
    model: module whose parameters receive the gradients.
    loss_fn: callable `loss_fn(output, target)` returning a scalar loss.
    lr: factor the accumulated gradient is multiplied by.
    data_ld: iterable yielding `(x, y)` batches.
    g_out: sequence of tensors, one per model parameter and in the same order
      as `model.parameters()`; overwritten in place with `lr * grad`.
    bsize: minimum number of samples to process before stopping.

  Returns:
    The number of samples actually processed (may exceed `bsize`).
  """
  # Run on whatever device the model lives on instead of hard-coding CUDA;
  # for a model that was moved with .cuda() this is the same device.
  first_param = next(model.parameters(), None)
  device = None if first_param is None else first_param.device

  model.zero_grad()
  cnt = 0
  for (x, y) in data_ld:
    if device is not None:
      x, y = x.to(device), y.to(device)
    y_out = model(x)
    loss_fn(y_out, y).backward()
    cnt += len(y_out)
    if cnt >= bsize:
      break

  with torch.no_grad():
    for i, p in enumerate(model.parameters()):
      if p.grad is None:
        # Frozen / unused parameter (or no batch was processed): it
        # contributes nothing to the update.
        g_out[i].zero_()
      else:
        g_out[i].copy_(p.grad)
        g_out[i].mul_(lr)
  return cnt


def update_model(model, g):
  """Apply one descent step in place: every parameter W becomes W - g[i].

  Args:
    model: module whose parameters are modified.
    g: sequence of tensors indexed in the order of `model.parameters()`;
      each entry is subtracted from the matching parameter.
  """
  with torch.no_grad():
    for idx, param in enumerate(model.parameters()):
      step = g[idx]
      param.data.sub_(step)


def compute_bound(n, m, T, d, eps, sum_lr2grad2):
  """Evaluate the bound that is logged as 'bound_2' during training.

  Args:
    n: total number of training samples.
    m: number of samples in the subset J.
    T: number of steps taken so far.
    d: number of model parameters.
    eps: rounding grid width; must exceed ZERO for the bound to be defined.
    sum_lr2grad2: running sum of the squared norms accumulated by the caller.

  Returns:
    The bound value, or -1. when `eps` is not above ZERO.
  """
  if not eps > ZERO:
    return -1.

  delta = 0.1
  n_free = n - m
  # Term that depends only on the confidence level delta.
  confidence_term = (np.log(1 / delta) + 3) / n_free
  # Term driven by the accumulated squared gradient differences.
  inv_eps_sq = 1 / (eps ** 2.)
  log_factor = np.log(d) + np.log(T)
  trajectory_term = inv_eps_sq * log_factor * sum_lr2grad2 / n_free
  return confidence_term + trajectory_term


def train(fname, args, SI, SI_aug, SJ, SJ_aug, D, sgd = False):
  """Train a `models.SimpleNet` on CUDA and log progress to `fname`.

  Every step draws one gradient estimate from `SI_aug` and one from `SJ_aug`.
  Unless `sgd` is set (or `eps` is not above ZERO), the difference between the
  combined-batch gradient and the J-batch gradient is rounded to a grid of
  width `eps` before being applied. Momentum is added when `mmt` is above
  ZERO, and the learning rate is multiplied by 0.9 every 200 steps.

  Args:
    fname: path of the log file; it is truncated and starts with `str(args)`.
    args: dict with keys 'device_bsize', 'batch_size', 'T', 'lr', 'eps', 'mmt'.
    SI, SJ: the I / J training subsets without augmentation (used for
      accuracy logging only).
    SI_aug, SJ_aug: the same subsets with augmentation (used for gradients).
    D: test set.
    sgd: when True, skip the rounding step and apply the combined-batch
      gradient directly.

  Returns:
    The sum over steps of `lr2grad2 / lr**2`; it stays 0. when the rounding
    branch is never taken.
  """
  device_bsize = args['device_bsize']
  bsize = args['batch_size']
  # Start a fresh log. Create the directory first so a missing log folder does
  # not abort the run, and close the handle deterministically.
  Path(fname).parent.mkdir(parents=True, exist_ok=True)
  with open(fname, 'w') as log_file:
    log_file.write(str(args) + '\n')
  m = len(SJ_aug)
  n = m + len(SI_aug)
  print(n, m)

  SI_aug_ld = get_data_ld(SI_aug, device_bsize)
  SJ_aug_ld = get_data_ld(SJ_aug, device_bsize)
  SI_ld = get_data_ld(SI, device_bsize)
  SJ_ld = get_data_ld(SJ, device_bsize)
  test_ld = get_data_ld(D, device_bsize)

  model = models.SimpleNet()
  model.cuda()
  print('device_count:', torch.cuda.device_count())
  
  # Use data parallel.
  # device0 = torch.device("cuda:0")
  # model = torch.nn.DataParallel(model)
  # model.to(device0)
  
  # reduction='sum' so gradients accumulated over several device batches add
  # up to a per-sample sum; they are divided by the sample count below.
  loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
  loss_fn.cuda()
  
  # Set model.grad to 0.
  x0, y0 = next(iter(SJ_aug_ld))
  zero_grad(model, loss_fn, x0, y0)

  # Create tensors with the same shape as model.parameters().
  # NOTE(review): if clone_param returns zeros rather than a copy of the
  # weights, the momentum term at t == 0 is computed against p_last == 0 --
  # TODO confirm against utils.clone_param.
  p_last = clone_param(model)
  g_mmt = clone_param(model)
  g1 = clone_param(model)
  g2 = clone_param(model)
  
  d = number_of_parameter(model)
  print_log({'d': d}, fname)
  
  # Start training.
  model.train()
  lr = args['lr']
  eps = args['eps']
  mmt = args['mmt']

  sum_grad2 = sum_lr2grad2 = lr2grad2 = 0.

  for t in range(args['T']):
    print(time.strftime('%Y-%m-%d %H:%M:%S',time.localtime(time.time())), t)
    # Logging Information.
    if t % 50 == 0:
      SI_acc = accuracy(SI_ld, model)
      SJ_acc = accuracy(SJ_ld, model)
      te_acc = accuracy(test_ld, model)
      # Accuracy on the full training set, weighted by the subset sizes.
      tr_acc = (SI_acc * (n - m) + SJ_acc * m) / n
      bound_2 = compute_bound(n, m, t + 1, d, eps, sum_lr2grad2)
      log_info_dic = {'step' : t,
                      'tr_acc': tr_acc,
                      'te_acc': te_acc,
                      'lr2grad2': lr2grad2,
                      'bound_1' : 1. - SI_acc,
                      'bound_2': bound_2,}
      print_log(log_info_dic, fname)

    # g1 <-- lr * (sum of gradients over >= bsize samples of SI_aug)
    # g2 <-- lr * (sum of gradients over >= bsize samples of SJ_aug)
    nI = compute_sum_grad(model, loss_fn, lr, SI_aug_ld, g1, bsize)
    nJ = compute_sum_grad(model, loss_fn, lr, SJ_aug_ld, g2, bsize)
    nS = nI + nJ
    # g1 <-- lr * mean gradient over the combined I+J batch
    # g2 <-- lr * mean gradient over the J batch
    vec_add(g1, g2)
    vec_mul(g1, 1. / nS)
    vec_mul(g2, 1. / nJ)

    if eps > ZERO and (not sgd):
      # g1 <-- g2 + eps * round((g1 - g2) / eps)
      vec_sub(g1, g2)
      lr2grad2 = vec_norm2(g1)
      sum_lr2grad2 += lr2grad2
      # Same quantity with the learning-rate factor divided out.
      sum_grad2 += lr2grad2 / (lr ** 2.)
      vec_mul(g1, 1. / eps)
      vec_round(g1)
      vec_mul(g1, eps)
      vec_add(g1, g2)
    else:
      # Pure SGD.
      pass
    
    if mmt > ZERO:
      # g_mmt <-- (W_{t-1} - W_{t-2}) * (-mmt)
      copy_param(g_mmt, model)
      vec_sub(g_mmt, p_last)
      vec_mul(g_mmt, -mmt)
      # g1 <-- g_mmt + g1
      vec_add(g1, g_mmt)
      # Remember the current weights for the next step's momentum term.
      copy_param(p_last, model)

    # Decay the learning rate by 10% every 200 steps (not at step 0).
    if t % 200 == 0 and t > 0:
      lr = lr * 0.9

    # W{t} <-- W{t-1} - g1
    update_model(model, g1)
  return sum_grad2


def run_training_process(rid):
  """Run the FSGD and the plain-SGD experiment for run id `rid`.

  Each experiment is skipped when its log file already exists. The datasets
  are only built when at least one of the two runs is still missing, so
  re-running the script over finished ids does not reload CIFAR-10 each time.
  Both runs share the same I / J split.
  """
  fsgd_log = f'log/cifar10/fsgd/{rid}.out'
  sgd_log = f'log/cifar10/sgd/{rid}.out'
  need_fsgd = not Path(fsgd_log).exists()
  need_sgd = not Path(sgd_log).exists()
  if not (need_fsgd or need_sgd):
    return

  SI, SI_aug, SJ, SJ_aug, D = get_datasets(n_m=ARGS['n_m'])

  # Training by FSGD
  if need_fsgd:
    train(fsgd_log, ARGS, SI, SI_aug, SJ, SJ_aug, D)

  # Training by SGD
  if need_sgd:
    train(sgd_log, ARGS, SI, SI_aug, SJ, SJ_aug, D, sgd = True)

def run_training_random_label(rid, p):
  """Train once on data with random-label portion `p`, unless already logged.

  The log goes to log/cifar10/random_label/<p>/<rid>.out; when that file
  exists nothing is loaded or trained.
  """
  out_file = f'log/cifar10/random_label/{p}/{rid}.out'
  if not Path(out_file).exists():
    splits = get_datasets(n_m=ARGS['n_m'], p = p)
    train(out_file, ARGS, *splits)

def study_m_graddiff(rid):
  """Record how the value returned by train() varies with the size m of J.

  For m = 1000, 2000, ..., 9000 a fresh split is drawn with n_m = 50 / (m / 1000),
  a full training run is executed (its own log goes to log/cifar10/debug.out),
  and the returned sum is logged to log/cifar10/mgrad/<rid>.out. Nothing runs
  when that file already exists on entry.
  """
  fname = f'log/cifar10/mgrad/{rid}.out'
  if not Path(fname).exists():
    for thousands in range(1, 10):
      splits = get_datasets(n_m=50 / thousands)
      total = train('log/cifar10/debug.out', ARGS, *splits)
      print_log({'m': thousands * 1000, 'sum_grad2': total}, fname)

if __name__ == '__main__':
  # Portions of random labels to sweep for every run id.
  random_label_portions = (0.1, 0.2, 0.5)

  for rid in range(100):
    # FSGD vs SGD
    run_training_process(rid)

    # Random Labels
    for portion in random_label_portions:
      run_training_random_label(rid, portion)

    # gradient difference decreases as m increases
    study_m_graddiff(rid)
